Skip to main content

Flex Checkpoint工作记录

1. Flex Checkpoint关键组件

1.1 reshard_sharded_state_dict

def reshard_sharded_state_dict(
src_sharded_state_dict: ShardedStateDict,
dst_sharded_state_dict: ShardedStateDict,
process_group: Group,
coordinator_rank: int | None = 0,
offload: bool | None = False,
aoa_config: dist[str, list[str]] | None = None,
) -> None:

local_src_state_dict_shard_info = {
key: (
value.global_offset,
value.local_shape,
str(value.local_tensor.dtype).split(".")[-1],
value.global_shape,
value.is_flattened,
)
for key, value in src_sharded_state_dict.items()
}

global_src_state_dict_shard_info = []
dist.all_gather_object(
global_src_state_dict_shard_info,
local_src_state_dict_shard_info,
group=process_group,
)

src_state_dict_shard_info = {}
for rank_shard_info in global_src_state_dict_shard_info:
for key, tensor_shard_info in rank_shard_info.items():
if key not in src_state_dict_shard_info:
src_state_dict_shard_info[key] = []
src_state_dict_shard_info[key].append(tensor_shard_info)

# check validity
check_src_state_dict_validity(src_state_dict_shard_info)

local_dst_state_dict_shard_info = {
key: (
value.global_offset,
value.local_shape,
str(value.local_tensor.dtype).split(".")[-1],
value.global_shape,
value.is_flattened,
)
for key, value in dst_sharded_state_dict.items()
}

global_dst_state_dict_shard_info = []
dist.all_gather_object(
global_dst_state_dict_shard_info,
local_dst_state_dict_shard_info,
group=process_group,
)

dst_state_dict_shard_info = {}
for rank_shard_info in global_dst_state_dict_shard_info:
for key, tensor_shard_info in rank_shard_info.items():
if key not in dst_state_dict_shard_info:
dst_state_dict_shard_info[key] = []
dst_state_dict_shard_info[key].append(tensor_shard_info)

# check validity
check_dst_state_dict_validity(dst_state_dict_shard_info)
check_src_dst_state_dict_validity(
src_state_dict_shard_info, dst_state_dict_shard_info
)

# build metadata
state_dict_metadata = {
tensor_name: [
LocalTensorMetadata(
global_offset=shard_info[0],
local_shape=shard_info[1],
dtype=shard_info[2],
)
for shard_info in shard_infos
]
for tensor_name, shard_infos in src_state_dict_shard_info.items()
}

virtual_file_path = f"vfile_{dist.get_rank()}"
local_storage_metadata = {
LocalTensorIndex(
tensor_key=value.key,
global_offset=value.global_offset,
): virtual_file_path
for key, value in src_sharded_state_dict.items()
}

global_storage_metadata: list[dict[LocalTensorIndex, str]] = []
dist.all_gather_object(
global_storage_metadata,
local_storage_metadata,
group=process_group,
)

# Merge storage metadata
storage_metadata: dict[LocalTensorIndex, str] = {}
for rank_storage_metadata in global_storage_metadata:
storage_metadata.update(rank_storage_metadata)

# Prepare metadata for loading
metadata = Metadata(
state_dict_metadata=state_dict_metadata,
storage_metadata=storage_metadata,
flat_mapping=None,
)

# Extract local tensors
src_state_dict = {
key: value.local_tensor for key, value in src_sharded_state_dict.items()
}
dst_state_dict = dst_sharded_state_dict
# reshard using _load_state_dict
_load_state_dict(
target_state_dict=dst_state_dict,
source_state_dict={virtual_file_path: src_state_dict},
metadata_list=[metadata],
coordinator_rank=coordinator_rank,
process_group=process_group,
offload=offload,
)

​ 这个函数实际是为了构建reshard过程中需要的metadata,实际的reshard操作,在load_state_dict里面。state_dict_metadatastorage_metadata 最终都包含了所有 rank 的分片信息,是全局的完整信息。

​ 这里使用virtual_file_path是因为此时实际的数据已经可以取到,即每个rank上local_tensor的实际值,无需再从文件中读取,这么做是为了整个格式上的对齐。

1.1.1 全局信息的构建过程

state_dict_metadata 的构建,state_dict_metadata用来保存Tensor的全局元数据信息
# 步骤1:每个 rank 收集自己的分片信息
local_src_state_dict_shard_info = {
key: (
value.global_offset,
value.local_shape,
str(value.local_tensor.dtype).split(".")[-1],
value.global_shape,
value.is_flattened,
)
for key, value in src_sharded_state_dict.items()
}

# 步骤2:全局收集所有 rank 的信息
global_src_state_dict_shard_info = []
dist.all_gather_object(
global_src_state_dict_shard_info,
local_src_state_dict_shard_info,
group=process_group,
)

# 结果:每个 rank 都有所有 rank 的信息
global_src_state_dict_shard_info = [
# rank 0 的信息
{"linear.weight": ((0, 0), (256, 512), "float32", (1024, 512), False)},
# rank 1 的信息
{"linear.weight": ((256, 0), (256, 512), "float32", (1024, 512), False)},
# rank 2 的信息
{"linear.weight": ((512, 0), (256, 512), "float32", (1024, 512), False)},
# rank 3 的信息
{"linear.weight": ((768, 0), (256, 512), "float32", (1024, 512), False)},
]

# 步骤3:重组为按张量分组的全局信息
src_state_dict_shard_info = {
"linear.weight": [
((0, 0), (256, 512), "float32", (1024, 512), False), # rank 0
((256, 0), (256, 512), "float32", (1024, 512), False), # rank 1
((512, 0), (256, 512), "float32", (1024, 512), False), # rank 2
((768, 0), (256, 512), "float32", (1024, 512), False), # rank 3
]
}

# 步骤4:构建全局的 state_dict_metadata
state_dict_metadata = {
"linear.weight": [
LocalTensorMetadata(global_offset=(0, 0), local_shape=(256, 512), dtype="float32"), # rank 0
LocalTensorMetadata(global_offset=(256, 0), local_shape=(256, 512), dtype="float32"), # rank 1
LocalTensorMetadata(global_offset=(512, 0), local_shape=(256, 512), dtype="float32"), # rank 2
LocalTensorMetadata(global_offset=(768, 0), local_shape=(256, 512), dtype="float32"), # rank 3
]
}
storage_metadata 的构建,storage_metadata 用来保存Tensor实际数据保存的位置信息
# 步骤1:每个 rank 构建自己的存储映射
virtual_file_path = f"vfile_{dist.get_rank()}"
local_storage_metadata = {
LocalTensorIndex(
tensor_key=value.key,
global_offset=value.global_offset,
): virtual_file_path
for key, value in src_sharded_state_dict.items()
}

# rank 0 的本地映射
local_storage_metadata = {
LocalTensorIndex("linear.weight", (0, 0)): "vfile_0",
}

# 步骤2:全局收集所有 rank 的存储映射
global_storage_metadata: list[dict[LocalTensorIndex, str]] = []
dist.all_gather_object(
global_storage_metadata,
local_storage_metadata,
group=process_group,
)

# 结果:每个 rank 都有所有 rank 的存储映射
global_storage_metadata = [
# rank 0 的映射
{LocalTensorIndex("linear.weight", (0, 0)): "vfile_0"},
# rank 1 的映射
{LocalTensorIndex("linear.weight", (256, 0)): "vfile_1"},
# rank 2 的映射
{LocalTensorIndex("linear.weight", (512, 0)): "vfile_2"},
# rank 3 的映射
{LocalTensorIndex("linear.weight", (768, 0)): "vfile_3"},
]

# 步骤3:合并为全局的 storage_metadata
storage_metadata: dict[LocalTensorIndex, str] = {}
for rank_storage_metadata in global_storage_metadata:
storage_metadata.update(rank_storage_metadata)

# 最终的全局 storage_metadata
storage_metadata = {
LocalTensorIndex("linear.weight", (0, 0)): "vfile_0", # rank 0
LocalTensorIndex("linear.weight", (256, 0)): "vfile_1", # rank 1
LocalTensorIndex("linear.weight", (512, 0)): "vfile_2", # rank 2
LocalTensorIndex("linear.weight", (768, 0)): "vfile_3", # rank 3
}

1.1.2 为什么需要全局信息?

重分片需要完整的分片信息
# 重分片过程:
# 源:4 个分片 -> 目标:2 个分片

# 需要知道所有源分片的信息才能正确重分片
source_shards = [
((0, 0), (256, 512)), # rank 0
((256, 0), (256, 512)), # rank 1
((512, 0), (256, 512)), # rank 2
((768, 0), (256, 512)), # rank 3
]

# 目标分片需要从多个源分片组合数据
target_shard_0 = combine(source_shards[0], source_shards[1]) # 需要 rank 0 和 rank 1 的数据
target_shard_1 = combine(source_shards[2], source_shards[3]) # 需要 rank 2 和 rank 3 的数据
数据访问需要全局映射
# _load_state_dict 需要知道:
# 1. 每个分片在哪里(storage_metadata)
# 2. 每个分片的形状和位置(state_dict_metadata)

def load_shard(tensor_name, global_offset):
# 根据全局信息找到对应的分片
index = LocalTensorIndex(tensor_name, global_offset)
file_path = storage_metadata[index] # "vfile_0"

# 从对应的数据源获取数据
if file_path in source_state_dict:
return source_state_dict[file_path][tensor_name]
验证需要全局视图
# 验证分片完整性需要全局信息
def validate_completeness():
# 检查是否所有分片都存在
expected_shards = [
(0, 0), (256, 0), (512, 0), (768, 0)
]

for offset in expected_shards:
index = LocalTensorIndex("linear.weight", offset)
if index not in storage_metadata:
raise ValueError(f"Missing shard at {offset}")

1.2 utils相关工具组件总结

1. 索引转换工具

ravel_index(indices, shape)
def ravel_index(indices, shape):
idx = 0
for i, dim in zip(indices, shape):
idx = idx * dim + i
return idx

作用:将多维索引转换为线性索引(行优先顺序)

详细解释

# 例子:shape = (2, 3, 4)
# 多维索引 (1, 2, 3) 转换为线性索引

# 计算过程:
# i=0: idx = 0 * 2 + 1 = 1
# i=1: idx = 1 * 3 + 2 = 5
# i=2: idx = 5 * 4 + 3 = 23

# 结果:线性索引 = 23
# 验证:在2×3×4的张量中,位置(1,2,3)的线性索引确实是23

应用场景

  • 将多维张量的位置转换为内存中的线性地址
  • 在分片计算中定位元素在全局张量中的位置
unravel_index(idx, shape)
def unravel_index(idx, shape):
indices = []
for dim in reversed(shape):
indices.append(idx % dim)
idx //= dim
return tuple(reversed(indices))

作用:将线性索引转换为多维索引

详细解释

# 例子:shape = (2, 3, 4), idx = 23
# 线性索引 23 转换为多维索引

# 计算过程(从右到左):
# dim=4: indices.append(23 % 4 = 3), idx = 23 // 4 = 5
# dim=3: indices.append(5 % 3 = 2), idx = 5 // 3 = 1
# dim=2: indices.append(1 % 2 = 1), idx = 1 // 2 = 0

# 结果:多维索引 = (1, 2, 3)

应用场景

  • 从内存地址恢复多维张量的位置
  • 在分片重建时确定元素在全局张量中的坐标

2. 切片计算工具

minimal_nd_slice(shape, flat_start, flat_end)
def minimal_nd_slice(shape, flat_start, flat_end):
start_idx = unravel_index(flat_start, shape)
end_idx = unravel_index(flat_end - 1, shape)
min_slices = []
for axis in range(len(shape)):
if axis == 0:
s = start_idx[axis]
e = end_idx[axis] + 1
else:
if start_idx[axis - 1] == end_idx[axis - 1]:
s = min(start_idx[axis], end_idx[axis])
e = max(start_idx[axis], end_idx[axis]) + 1
else:
s = 0
e = shape[axis]
min_slices.append((s, e))
return min_slices, start_idx, end_idx

作用:计算包含给定扁平化范围的最小N维切片

详细解释

# 例子:shape = (4, 3), flat_start = 5, flat_end = 8
# 扁平化范围 [5, 8) 转换为最小切片

# 计算过程:
# start_idx = unravel_index(5, (4, 3)) = (1, 2)
# end_idx = unravel_index(7, (4, 3)) = (2, 1)

# 对于axis=0:
# s = 1, e = 2 + 1 = 3

# 对于axis=1:
# start_idx[0] = 1, end_idx[0] = 2, 不相等
# 所以 s = 0, e = 3

# 结果:min_slices = [(1, 3), (0, 3)]
# 这表示需要切片 [1:3, 0:3]

应用场景

  • 将扁平化的索引范围转换为最优的多维切片
  • 减少数据传输量,提高效率
flat_range_in_min_slice(shape, min_slices, flat_start, flat_end)
def flat_range_in_min_slice(shape, min_slices, flat_start, flat_end):
min_starts = tuple(s[0] for s in min_slices)
min_flat_start = ravel_index(min_starts, shape)
return flat_start - min_flat_start, flat_end - min_flat_start

作用:计算在最小切片中的相对扁平化范围

详细解释

# 例子:shape = (4, 3), min_slices = [(1, 3), (0, 3)]
# flat_start = 5, flat_end = 8

# 计算过程:
# min_starts = (1, 0)
# min_flat_start = ravel_index((1, 0), (4, 3)) = 3
# 相对范围 = (5 - 3, 8 - 3) = (2, 5)

# 这表示在最小切片内的相对位置

应用场景

  • 计算在切片内的相对偏移
  • 用于精确的数据提取和复制

3. 状态字典检查工具

is_sharded_state_dict(o)
def is_sharded_state_dict(o):
if not isinstance(o, dict):
return False

values = list(o.values())
has_sharded_weight = any(isinstance(v, ShardedWeight) for v in values)

if has_sharded_weight:
if not all(isinstance(v, ShardedWeight) for v in values):
raise TypeError(
"All values must be ShardedWeight if any value is ShardedWeight."
)
return True
else:
return False

作用:检查字典是否为分片状态字典

详细解释

# 检查规则:
# 1. 必须是字典类型
# 2. 如果任何值是ShardedWeight,则所有值都必须是ShardedWeight
# 3. 不允许混合类型

# 例子:
valid_dict = {
"weight": ShardedWeight(...),
"bias": ShardedWeight(...)
} # 返回 True

invalid_dict = {
"weight": ShardedWeight(...),
"bias": paddle.Tensor(...)
} # 抛出TypeError

应用场景

  • 验证检查点格式的正确性
  • 确保状态字典的一致性

4. 重叠区域计算工具

get_overlap_region(desc_offset, desc_shape, shard_offset, shard_shape)
def get_overlap_region(desc_offset, desc_shape, shard_offset, shard_shape):
ndim = len(desc_offset)
overlap_offset = []
overlap_shape = []
desc_starts = []
shard_starts = []
for i in range(ndim):
desc_lo = desc_offset[i]
desc_hi = desc_offset[i] + desc_shape[i]
shard_lo = shard_offset[i]
shard_hi = shard_offset[i] + shard_shape[i]
# overlap
lo = max(desc_lo, shard_lo)
hi = min(desc_hi, shard_hi)
if lo >= hi:
return False, None, None, None, None
overlap_offset.append(lo)
overlap_shape.append(hi - lo)
desc_starts.append(lo - desc_lo)
shard_starts.append(lo - shard_lo)
return True, overlap_offset, overlap_shape, desc_starts, shard_starts

作用:计算两个分片之间的重叠区域

详细解释

# 例子:2D张量
# desc: offset=(0,0), shape=(4,4)
# shard: offset=(2,2), shape=(4,4)

# 计算过程:
# 维度0:
# desc_lo=0, desc_hi=4, shard_lo=2, shard_hi=6
# lo = max(0,2) = 2, hi = min(4,6) = 4
# overlap_offset[0] = 2, overlap_shape[0] = 2
# desc_starts[0] = 2-0 = 2, shard_starts[0] = 2-2 = 0

# 维度1:
# desc_lo=0, desc_hi=4, shard_lo=2, shard_hi=6
# lo = max(0,2) = 2, hi = min(4,6) = 4
# overlap_offset[1] = 2, overlap_shape[1] = 2
# desc_starts[1] = 2-0 = 2, shard_starts[1] = 2-2 = 0

# 结果:
# 重叠区域:offset=(2,2), shape=(2,2)
# 在desc中的起始:(2,2)
# 在shard中的起始:(0,0)

应用场景

  • 计算不同分片策略间的数据重叠
  • 为数据复制提供精确的范围信息

5. 分片数据复制工具

assign_sharded_slice(src_desc, src_shard, dst_desc, dst_shard)
def assign_sharded_slice(src_desc, src_shard, dst_desc, dst_shard):
# 1. 计算源分片的重叠区域
src_has, _, overlap_shape, src_desc_starts, src_shard_starts = (
get_overlap_region(
src_desc.global_offset,
src_desc.local_shape,
src_shard.global_offset,
src_shard.local_shape,
)
)

# 2. 计算目标分片的重叠区域
dst_has, _, overlap_shape2, dst_desc_starts, dst_shard_starts = (
get_overlap_region(
dst_desc.global_offset,
dst_desc.local_shape,
dst_shard.global_offset,
dst_shard.local_shape,
)
)

# 3. 验证重叠区域一致性
assert src_has or dst_has, "no overlap!"
assert overlap_shape == overlap_shape2, "overlap shape mismatch!"

# 4. 执行数据复制
axes = list(range(len(overlap_shape)))

src_tensor_slice = paddle.slice(
src_shard.local_tensor,
axes=axes,
starts=src_shard_starts,
ends=[s + o for s, o in zip(src_shard_starts, overlap_shape)],
)

dst_tensor_slice = paddle.slice(
dst_shard.local_tensor,
axes=axes,
starts=dst_shard_starts,
ends=[s + o for s, o in zip(dst_shard_starts, overlap_shape)],
)

paddle.assign(src_tensor_slice, dst_tensor_slice)

作用:在不同分片间复制重叠数据

详细解释

# 完整流程:
# 1. 计算源分片与描述符的重叠区域
# 2. 计算目标分片与描述符的重叠区域
# 3. 验证两个重叠区域的一致性
# 4. 从源分片提取重叠部分
# 5. 复制到目标分片

# 例子:从tp2转换到tp4
# src_desc: 描述tp2时的分片布局
# src_shard: tp2时的实际数据
# dst_desc: 描述tp4时的分片布局
# dst_shard: tp4时的目标数据

# 函数会:
# 1. 找到tp2和tp4分片的重叠部分
# 2. 将tp2的重叠数据复制到tp4的对应位置

应用场景

  • 分片策略转换时的数据重分布
  • 检查点加载时的数据恢复

6. 信息合并工具

merge_shard_info_list(list_of_dicts)
def merge_shard_info_list(list_of_dicts):
merged = defaultdict(list)
for info in list_of_dicts:
for k, v in info.items():
merged[k].extend(v)
return dict(merged)

作用:合并多个分片信息字典

详细解释

# 例子:
list_of_dicts = [
{"param1": [info1, info2]},
{"param1": [info3], "param2": [info4]},
{"param2": [info5, info6]}
]

# 合并结果:
merged = {
"param1": [info1, info2, info3],
"param2": [info4, info5, info6]
}

应用场景

  • 收集所有rank的分片信息
  • 构建全局的分片视图

7. 描述符构建工具

build_shard_desc(val)
def build_shard_desc(val):
return ShardedWeightDesc(
key=val.key,
local_shape=tuple(val.local_shape),
global_shape=tuple(val.global_shape),
global_offset=tuple(val.global_offset),
)

作用:从ShardedWeight构建ShardedWeightDesc

详细解释

# 转换过程:
# 输入:ShardedWeight对象(包含实际数据)
# 输出:ShardedWeightDesc对象(仅包含元数据)

# 例子:
sharded_weight = ShardedWeight(
key="linear.weight",
local_tensor=paddle.Tensor(...), # 实际数据
local_shape=(1024, 512),
global_shape=(1024, 2048),
global_offset=(0, 0)
)

# 转换为:
shard_desc = ShardedWeightDesc(
key="linear.weight",
local_shape=(1024, 512),
global_shape=(1024, 2048),
global_offset=(0, 0)
)

应用场景

  • 提取分片权重的元数据信息
  • 用于分片信息的传输和存储

1.3 sharded_tensor的关键组件

class ShardedTensor:
"""
Represents a local shard of a distributed tensor parameter.

Args:
key (str): The name of the parameter.
local_tensor (Tensor): The local shard of the parameter.
local_shape (Tuple[int, ...]): The shape of the local shard.
global_shape (Tuple[int, ...]): The global logical shape of the parameter.
global_offset (Tuple[int, ...]): The offset of the local shard in the global parameter.
is_flattened (bool, optional): Whether the parameter has been flattened (used in sharding_v2 scenarios). Default is False.
flattened_range (slice, optional): If the parameter is flattened, this indicates the index range of the actual local shard within the local_tensor.
"""

def __init__(
self,
key: str,
local_tensor: Tensor,
local_shape: tuple[int, ...],
global_shape: tuple[int, ...],
global_offset: tuple[int, ...],
is_flattened: bool = False,
flattened_range: slice | None = None,
) -> None:
self.key = key
self.local_tensor = local_tensor
self.local_shape = local_shape
self.global_shape = global_shape
self.global_offset = global_offset
self.is_flattened = is_flattened
self.flattened_range = flattened_range

def __str__(self) -> str:
"""Returns a formatted string representation of the sharded tensor."""
return (
f"ShardedTensor(\n"
f" key={self.key},\n"
f" local_tensor={type(self.local_tensor).__name__}(shape={self.local_tensor.shape}),\n"
f" local_shape={self.local_shape},\n"
f" global_shape={self.global_shape},\n"
f" global_offset={self.global_offset},\n"
f" flattened_range={self.flattened_range}\n"
f")"
)

def shard_weight(
key: str,
weight: Tensor,
axis: int,
group: Group,
) -> ShardedTensor:
"""Creates a ShardedTensor by splitting the input tensor along a specified axis.

Args:
key: Unique identifier for the tensor.
weight: The input tensor to be sharded.
axis: The axis along which to shard the tensor.
group: The process group used for distributed communication.

Returns:
A ShardedTensor representing the local portion of the global tensor.
"""
if axis < 0 or axis >= len(weight.shape):
raise ValueError(
f"Shard axis {axis} is invalid for tensor with shape {weight.shape}"
)

# Get hybrid communication group and rank information
hcg = fleet.get_hybrid_communicate_group()
current_rank = group.rank
world_size = group.nranks

# Calculate shapes and offsets
local_shape = weight.shape
global_shape = deepcopy(local_shape)
global_shape[axis] = local_shape[axis] * world_size
global_shape = tuple(global_shape)
local_shape = tuple(local_shape)
global_offset = [0] * len(global_shape)
if world_size > 1:
global_offset[axis] = current_rank * local_shape[axis]
global_offset = tuple(global_offset)

return ShardedTensor(
key=key,
local_tensor=weight,
local_shape=local_shape,
global_shape=global_shape,
global_offset=global_offset,
)

def build_sharded_state_dict(
state_dict: dict[str, Tensor],
shard_rules: dict[str, int] | None = None,
prefix: str = "",
) -> dict[str, ShardedTensor]:
"""Converts a regular state dict to a sharded state dict based on sharding rules.

Args:
state_dict: The original state dictionary containing tensors
shard_rules: Dictionary mapping tensor names to their sharding axes.
If None, treated as empty dict (no tensor parallelism).
prefix: Optional prefix to prepend to all tensor keys

Returns:
Dictionary with the same keys as input but values converted to ShardedTensor
or regular Tensor based on sharding rules.

Note:
Tensors not in shard_rules will be wrapped as non-sharded ShardedTensors.
"""
shard_rules = shard_rules or {}
sharded_state_dict = {}

for key, tensor in state_dict.items():
full_key = f"{prefix}{key}" if prefix else key

if key in shard_rules:
# Apply tensor parallelism sharding
sharded_state_dict[full_key] = (
make_tp_sharded_tensor_for_checkpoint(
key=full_key,
tensor=tensor,
tensor_parallel_axis=shard_rules[key],
)
)
else:
# Create regular sharded tensor (non-tensor-parallel)
sharded_state_dict[full_key] = make_replicated_sharded_tensor(
key=full_key,
tensor=tensor,
)

return sharded_state_dict

​ 主要是ShardedTensor类和build_sharded_state_dictshard_weight两个接口,ShardedTensor主要是作为后续shard_state_dict中的基础单元,即{key:ShardedTensor},原来版本是普通的Tensor,而现在的ShardedTensor携带了Tensor切分的信息,主要是local_shapeglobal_shapeglobal_offset则可以据此对local_tensor进行全局tensor的重建,再对齐进行reshard。build_sharded_state_dict是在普通的state_dict的基础上,对于需要做分布式处理的(即shard)tensor进行切分标记,将tensor转化为ShardedTensormake_tp_sharded_tensor_for_checkpoint其实就是做mp参数并行,里面调用的就是shard_weight接口,返回一个ShardedTensor;对于不需要切分的,也要用make_replicated_sharded_tensor处理,将其转化为统一的ShardedTensor类,这部分处理无需调用shard_weight,直接返回ShardedTensorlocal_shape=global_shape,因为每个rank上保存的这部分数据都一样。而shard_weight,传入进来的tensor,对应切分的那个维度的数据,每个rank都不一样(对于shard组来说),因此将每个rank上该tensor的对应维度的shape加起来,即可得到global_shape,从而构造出具有分布式信息的Tensor。

1.4 load_state_dict的关键组件

1.4.1 get_rank_to_files(与原来的一致)

ef get_rank_to_files(
metadata_list,
local_data_files,
state_dict,
process_group,
use_dist,
mw_name_compatibility=True,
):
"""
Get the mapping of rank to its accessible files.
"""

# The necessary files to be read
tensor_key_list = []
necessary_files = []
mw_name_compatibility_mapping = {}

for metadata in metadata_list:
for local_tensor_index, file_name in metadata.storage_metadata.items():
assert (
local_tensor_index not in tensor_key_list
), f"Duplicate tensor_key:{local_tensor_index} found. Check whether the metadata."
tensor_key_list.append(local_tensor_index.tensor_key)
if local_tensor_index.tensor_key in state_dict:
necessary_files.append(file_name)

all_necessary_files = []
if use_dist:
paddle.distributed.all_gather_object(
all_necessary_files, necessary_files, process_group
)
else:
all_necessary_files.append(necessary_files)

global_necessary_files = [
file for files in all_necessary_files for file in files
]

global_necessary_files_set = set(global_necessary_files)
if len(global_necessary_files_set) <= 0:
logger.warning(
"No necessary data files found in the checkpoint directory. Please check the metadata."
)
missing_keys = set(state_dict.keys())
return {}, missing_keys, mw_name_compatibility_mapping

# allgather all accessible files
global_data_files = []
if use_dist:
paddle.distributed.all_gather_object(
global_data_files, local_data_files, process_group
)
else:
global_data_files.append(local_data_files)
tmp = []
for files in global_data_files:
tmp += files
global_data_files_set = set(tmp)
logger.debug(
f"necessary_data_files_set:{global_necessary_files_set}, global_data_files_set:{global_data_files_set}"
)
# check necessary files in global_data_files
assert (
global_data_files_set & global_necessary_files_set
== global_necessary_files_set
), f"The checkpoint files are not complete. Please check the checkpoint directory. global_data_files_set:{global_data_files_set}, necessary_data_files_set:{global_necessary_files_set}"
missing_keys = set(state_dict.keys()) - set(tensor_key_list)
if len(missing_keys) > 0:
if mw_name_compatibility:
mw_name_compatibility_mapping = _modify_mw_name_for_compatibility(
state_dict, missing_keys, tensor_key_list
)
if len(missing_keys) > 0:
logger.warning(
f"Missing keys:{missing_keys}, check whether the checkpoint is complete."
)
else:
logger.warning(
f"Missing keys:{missing_keys}, check whether the checkpoint is complete."
)

rank_to_files = {}
for rank, need_files in enumerate(all_necessary_files):
seen = set()
unique_need_files = [
f for f in need_files if not (f in seen or seen.add(f))
]
rank_to_files[rank] = unique_need_files
logger.debug(f"mapping rank_to_files:{rank_to_files}")
return rank_to_files, missing_keys, mw_name_compatibility_mapping

​ 根据保存的storage_metadata,遍历当前rank上的state_dict,根据local_tensor_index.tensor_key是否在state_dict中,来确定是否需要当前local_tensor_index对应的文件,如果需要就添加到necessary_files中,all_necessary_files保存的是所有rank的necessary_files,如下:

all_necessary_files = [
["0_0.distcp", "1_0.distcp"], # rank 0 需要的文件
["2_0.distcp", "3_0.distcp"], # rank 1 需要的文件
["4_0.distcp", "5_0.distcp"], # rank 2 需要的文件
["6_0.distcp", "7_0.distcp"], # rank 3 需要的文件
]

​ 即key就是rank id,value就是该rank需要的文件列表,seen是用来去重的。

1.5 paddlenlp适配

结论:因为只有 LlamaLMHead 的分片保存规则在本文件里需要“特殊约定”,其它层要么已经在各自实现里内建了 sharded_state_dict,要么可以用默认递归收集;而 LM Head 需要显式告诉检查点系统“按哪一维切”。

为什么只有 LlamaLMHeadPipelinePretrainedModel 需要适配?

1.LlamaLMHead
  • LM Head 的权重轴不固定LlamaLMHead 支持 transpose_y 和词表并行(vocab parallel)。这会改变权重逻辑形状与“被切分的维度”:

    • transpose_y=Truetie_word_embeddings 时,weight 形状是 [vocab_size, hidden_size],切分轴应为 axis=0

    • 否则通常是 [hidden_size, vocab_size],切分轴应为 axis=1

    • 代码中专门计算了 axis = 0 if self.transpose_y else 1,然后:

      # L2000-L2006
      state_dict = self.state_dict(structured_name_prefix="")
      return build_sharded_state_dict(state_dict, {"weight": axis}, structured_name_prefix)

      这一步确保统一检查点能正确记录“词表维度”的切分方式,便于跨并行策略重构权重。

  • 其它模块已有分片实现或可用默认机制

    • 注意力/MLP里用的 ColumnParallelLinearRowParallelLinear(以及对应的 Sequence Parallel 版本)在它们各自的实现里已经处理了分片参数保存;模型其他权重(如 LlamaRMSNorm.weight)不涉及并行切分轴的歧义,默认递归即可。
    • 词嵌入 VocabParallelEmbedding 也在并行库里有自己的分布式属性与导出路径。
  • LM Head 还涉及权重共享与并行输出

    • tie_word_embeddings 时和 Embedding 共享权重,且 is_distributed/split_axis 被设置用于张量并行。
    • 因此 LM Head 成为“需要显式声明切分轴”的最特殊一层,避免统一检查点在重构/重分片(如从 TP2 切换到 TP4)时出错。
2.PipelinePretrainedModel

结论:因为只有 LlamaLMHead 的分片保存规则在本文件里需要“特殊约定”,其它层要么已经在各自实现里内建了 sharded_state_dict,要么可以用默认递归收集;而 LM Head 需要显式告诉检查点系统“按哪一维切”。

  • LM Head 的权重轴不固定LlamaLMHead 支持 transpose_y 和词表并行(vocab parallel)。这会改变权重逻辑形状与“被切分的维度”:

    • transpose_y=Truetie_word_embeddings 时,weight 形状是 [vocab_size, hidden_size],切分轴应为 axis=0

    • 否则通常是 [hidden_size, vocab_size],切分轴应为 axis=1

    • 代码中专门计算了 axis = 0 if self.transpose_y else 1,然后:

      # L2000-L2006
      state_dict = self.state_dict(structured_name_prefix="")
      return build_sharded_state_dict(state_dict, {"weight": axis}, structured_name_prefix)

      这一步确保统一检查点能正确记录“词表维度”的切分方式,便于跨并行策略重构权重。

  • 其它模块已有分片实现或可用默认机制

    • 注意力/MLP里用的 ColumnParallelLinearRowParallelLinear(以及对应的 Sequence Parallel 版本)在它们各自的实现里已经处理了分片参数保存;模型其他权重(如 LlamaRMSNorm.weight)不涉及并行切分轴的歧义,默认递归即可。
    • 词嵌入 VocabParallelEmbedding 也在并行库里有自己的分布式属性与导出路径。
  • LM Head 还涉及权重共享与并行输出

    • tie_word_embeddings 时和 Embedding 共享权重,且 is_distributed/split_axis 被设置用于张量并行。
    • 因此 LM Head 成为“需要显式声明切分轴”的最特殊一层,避免统一检查点在重构/重分片(如从 TP2 切换到 TP4)时出错。

2.对相关的分布式API添加shard_state_dict处理

2.1 VocabParallelEmbedding

2.1.1 接收的输入

​ 文本输入

用户输入: "Hello world, how are you?"

​ 分词(Tokenization)

分词结果: ["Hello", "world", ",", "how", "are", "you", "?"]

​ 词汇表映射(Vocabulary Mapping)

词汇表: {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3, 
"Hello": 4, "world": 5, ",": 6, "how": 7, "are": 8, "you": 9, "?": 10, ...}

映射结果: [4, 5, 6, 7, 8, 9, 10]

​ 输入到模型为词汇ID序列

模型接收的输入: x = [4, 5, 6, 7, 8, 9, 10]  (词汇ID序列)

因此,VocabParallelEmbedding接收到的输入x是[batch_size,seqlenth],即多组词汇ID序列。

2.1.2 处理输入

假设vocab_size=50000,embedding_dim=1024,即有50000个词,映射成向量用1024个特征表示,每个词对应一个1024长度的特征向量:

每个词汇ID对应矩阵中的一行:
word_id=0 -> W[0, :] = [0.1, 0.2, 0.3, ..., 0.1024]
word_id=1 -> W[1, :] = [0.5, 0.1, 0.8, ..., 0.2048]
word_id=2 -> W[2, :] = [0.3, 0.7, 0.2, ..., 0.3072]
...
word_id=499999 -> W[499999, :] = [0.9, 0.4, 0.6, ..., 0.1024]

输入为:

# 输入: x = [batch_size, seq_len] (词汇ID)
# 例如: x = [[100, 250000, 500000, 750000],
# [150, 250100, 500100, 750100]]

Vocab分割后:

GPU0: W[0:250000, :]     (250000行,1024列)
GPU1: W[250000:500000, :] (250000行,1024列)
GPU2: W[500000:750000, :] (250000行,1024列)
GPU3: W[750000:1000000, :] (250000行,1024列)

并行化后的查找过程:

对于输入词汇ID,每个GPU的处理:

GPU0 (负责词汇0-249999):
- 输入ID=100: 查找 W[100, :] = [0.1, 0.2, ..., 0.1024]
- 输入ID=150: 查找 W[150, :] = [0.3, 0.4, ..., 0.1024]
- 输入ID=250000: 不在范围内,返回零向量或特殊处理
- 输入ID=500000: 不在范围内,返回零向量或特殊处理

GPU1 (负责词汇250000-499999):
- 输入ID=100: 不在范围内,返回零向量
- 输入ID=250000: 查找 W[250000, :] = [0.5, 0.6, ..., 0.1024]
- 输入ID=250100: 查找 W[250100, :] = [0.7, 0.8, ..., 0.1024]
- 输入ID=500000: 不在范围内,返回零向量

GPU2 (负责词汇500000-749999):
- 输入ID=500000: 查找 W[500000, :] = [0.9, 0.1, ..., 0.1024]
- 输入ID=500100: 查找 W[500100, :] = [0.2, 0.3, ..., 0.1024]

GPU3 (负责词汇750000-999999):
- 输入ID=750000: 查找 W[750000, :] = [0.4, 0.5, ..., 0.1024]
- 输入ID=750100: 查找 W[750100, :] = [0.6, 0.7, ..., 0.1024]

最终将每张卡的结果做allreduce合并,则得到最终结果,输出为:[batch_size, seq_len, embedding_dim]

​ 一开始该层权重是随机初始化的,即,每个词虽然都用向量表示,但此时是无意义的,经过训练后,相近的词embedding的数据会逐渐相似,从而在推理时,正确找到每个词的embedding。

2.2 ColumnParallelLayer与RowParallelLayer同时使用的关系

2.2.1 ColumnParallelLayer

columnparallel.drawio

2.2.2 RowParallelLayer

Rowparallel.drawio

​ 可以看到,RowParallelLayer在计算的过程中,需要把输入拆分成两列分别在两张卡上做计算,最终两张卡都得到Parital状态的数据,而如果上一层是ColumnParallel则其计算的结果刚好分配到两个设备上(即结果被按列切分),而此结果正是RowParallelLayer需要的输入,那么就无需做通信,直接继续计算最后再做allreduce即可。

2.2.3 ColumnParallelLayer与RowParallelLayer的w和bias的切分

层间计算.drawio

​ 注意,在做y=x*W^T+b的计算时,首先乘积得到的数据是[batchsize,output_size],每一行表示一个数据,而bias是分别和每一行相加,因此bias是一个一维的向量,因此,当W按列切分时,bias需要按行切分,从而保持正确的计算关系。

​ 当添加了bias的时候,做RowParallelLayer和ColumnParallelLayer情况如下:

RowParallelLayer:

RowParallel_bias.drawio

​ RowParallelLayer只切w,不切bias

ColumnParallelLayer:

ColumnParallel_bias.drawio

​ ColumnParallelLayer切w的axis=1,切bias的axis=0

2.3 DygraphShardingOptimizerV2

核心目标

sharded_state_dict 是为了解决不同并行策略间状态转换的问题,以及V2情境下,optimizer被展开铺平的问题:

  • 例如从 tp2 切换到 tp4:需要重新划分参数
  • 保持数据完整性:确保参数和优化器状态正确转换
  • 支持断点续训:在不同并行配置间无缝切换
实现方法总结
1. 分片信息收集阶段
# 第一步:收集当前分片策略的信息
for comm_group, buffers in comm_group_buffers.items():
for buffer in buffers:
for param_name, grad_view in buffer._sharding_param_grad_view.items():
# 记录每个参数在当前rank的分片范围
param_slice_info[param_name] = (
grad_view._param_begin, # 分片起始位置
grad_view._param_end, # 分片结束位置
)
# 记录参数的完整形状信息
param_shape_info[param_name] = (
grad_view._param.shape, # 原始形状
grad_view._param.numel().item(), # 元素总数
grad_view._index, # 分片索引
grad_view._padded_size, # 填充大小
)

目的:记录当前分片策略下每个参数如何被分配到各个rank。

2. 全局信息同步阶段
# 第二步:收集所有rank的分片信息
for comm_group, buffers in comm_group_buffers.items():
# 从当前rank收集信息
param_slice_info["sharding_rank"] = comm_group.rank

# 通过all_gather收集所有rank的信息
gathered_info = []
paddle.distributed.all_gather_object(
gathered_info, param_slice_info, group=comm_group
)
all_rank_slice_info.extend(gathered_info)

目的:让每个rank都知道完整的分片分布情况,为后续重建做准备。

3. 部分分片张量识别阶段
# 第三步:识别哪些张量是部分分片的
for param_key, tensor in optim_state_dict.items():
base_name, _ = _generate_base_static_name(param_key)

if int(tensor.numel()) > 1: # 非标量张量
begin, end = merged_slice_info[base_name]
shape_info = merged_shape_info[base_name]

# 判断是否为部分分片:分片大小 < 原始大小
if shape_info and end > begin and end - begin < shape_info[1]:
partial_tensor_names.append(base_name)

目的:区分完全分片和部分分片的张量,它们需要不同的处理策略。

4. 偏移映射计算阶段
# 第四步:计算每个rank在完整张量中的偏移位置
for tensor_name in partial_tensor_names:
offset_mapping[tensor_name] = [0] * world_size

# 记录每个rank的分片大小
for info in all_rank_slice_info:
if tensor_name in info:
begin, end = info[tensor_name]
if end > begin:
offset_mapping[tensor_name][info["sharding_rank"]] = end - begin

# 转换为累积偏移
running_total = 0
for rank in range(world_size):
current_size = offset_mapping[tensor_name][rank]
offset_mapping[tensor_name][rank] = running_total
running_total += current_size

目的:为每个rank计算其在完整张量中的起始位置,用于重建完整张量。

5. 状态字典构建阶段
# 第五步:构建分片状态字典
for param_key, tensor in optim_state_dict.items():
base_name, optim_state_type = _generate_base_static_name(param_key)
struct_name = static_to_struct[base_name]
sharded_param = model_sharded_state_dict[struct_name]
unified_name = f"{struct_name}.{optim_state_type}"

# 处理三种不同类型的张量
if int(tensor.numel()) == 1:
# 标量参数:直接保存
sharded_weight = ShardedWeight(...)
elif base_name in partial_tensor_names:
# 部分分片张量:记录在完整张量中的位置
flattened_offset = offset_mapping[base_name][sharding_rank]
sharded_weight = ShardedWeight(
flattened_range=slice(flattened_offset, flattened_offset + int(tensor.numel()))
)
else:
# 完全分片张量:当前rank拥有完整分片
sharded_weight = ShardedWeight(
flattened_range=slice(0, int(tensor.numel()))
)

目的:为每个优化器状态创建包含完整分片信息的 ShardedWeight 对象。

关键设计思想
1. 分层信息记录
# 记录三个层次的信息:
# 1. 参数级:param_slice_info - 分片范围
# 2. 形状级:param_shape_info - 完整形状
# 3. 全局级:offset_mapping - 全局偏移
2. 分类处理策略
# 三种处理策略:
# 1. 标量参数:直接保存,无需分片信息
# 2. 部分分片张量:记录在完整张量中的位置
# 3. 完全分片张量:当前rank拥有完整分片
3. 全局视角构建
# 每个rank都收集全局信息:
# 1. 所有rank的分片范围
# 2. 完整的参数形状
# 3. 全局偏移映射

2.4 SP(序列并行)

​ 与ColumnParallel、RowParallel类似,只是维度发生在seq_len,且伴随tp(mp)使用。

2.5 关于shared_state_dict方法中structured_name_prefix为空的问题

image-20250906155027601

​ 实际上在这里会递归调用sub_layer的shared_state_dict方法,从而将当前层的name传递到sub_layer作为前缀。

3.测试Ernie中的一些问题

1.self.args.offload_optim

​ _offlad_optimizer导致保存的转换后的optimizer.pdopt中的动量都没保存成功

​ 模型转换时,如dp2->dp4,offload_optimizer处理后,此时state_dice()中只有master_params和shceduler的数据,动量都被卸载到cpu上了,导致保存失败。

image-20250818005155527

2.sharding4转纯dp2时,文件名不对应,无法加载

image-20250818005805429

​ shading对应保存的文件名是model_state_shardxx,但是纯dp加载的model文件名是model_state.paparams,因此保存shading4的ckpt,而此时换成纯dp2训练时,无法正确加载ckpt文件。

3.checkpoint文件路径问题

image-20250818005836965

  • checkpoint文件路径修改后,paddleformers得同步更新导入得load_state_dict和save_state_dict

4.MoElayer找不到config属性

c3c941b14e6dd9f31a799f97c6cf504e

原因:

82A5F2B7DE7C6CAC8E68029E33533D5D

​ 这块是因为还没定义就用了那个config的一些参数,我直接给注释掉了。

5.纯sharding出错,原因是在梯度累加时累加的数据类型有问题

img

img

​ 注释掉的为原来的代码,然而测试了最新的paddle发现,add_已经适配了fp16和float32两个不同精度的数相加的场景,估计是当时合入的pr造成的bug,已经被修复。

6.t2(ep2)->pp4,报源ckpt加载后的数据,缺少某个参数的优化器状态

​ lm_head与embedding共享一份weight,因此优化器内部的优化器状态也只有一份

​ 这会导致,在加载ckpt的时候,报错:

image-20250902162332049

主要原因

image-20250902195016450

​ 在加载ckpt时,需要初始化model和opt,而使用flex_ckpt框架时,对应的init_opt中是根据model里面的每个key来创建对应的opt状态,,此时embed_layer和lm_head_layer共用同一个参数,因此优化器状态只有一份,所以导致在ckpt中找不到初始化时创建的embed_tokens,导致报错。

问题追溯:

​ 打印出的model及其对应的value:

image-20250902193601212

image-20250902194929965

可以看到都指向同一个tensor,而具体实现在tie_weight:

image-20250902194128545

image-20250902194218577

image-20250902194256847

可以看到在这里面将lm_head 直接赋值为embedding对应的tensor

为什么共用一个参数,他们也共用一份优化器状态?

image-20250903214741558

image-20250903214815348

image-20250903215139624

image-20250903214850026

​ 以上是创建optimizer涉及到得流程,可以看到,optimizer中包含的参数,是根据params来去重的,即直接根据Tensor去重,而不是key,因此共享tensor的参数,只会有一份保留在optimizer的参数列表中,并且是第一次出现的参数。

image-20250903220104385

image-20250903220001562

image-20250903215229083

在创建累加器时,此时只有{key:embedding_0.w_0,shared_tensor}保留下来了,所以只有embedding的优化器状态创建了,就不会再创建lm_head的了,打印出来如下:

image-20250902194529172

为什么报错提示找不到embedding的优化器状态,而不是lm_head的优化器状态?

image-20250902194719815

在这里,因为两者的v.local_tensor.name一致,前者被覆盖了。

问题总结:

tp2(ep2)->pp4问题总结: 遇到的问题: 在pp4 load tp2(ep2)保存的ckpt时,加载AOAEngine,调用shape_propagation函数时,未被AOA规则改写的参数会做补全映射,而此时会判断补全的这个key是否在源策略(tp2ep2)中出现过,若没出现过则会报错,而此处就报错:找不到 ernie.embed_tokens.weight.moment1_0(其实所有的embed_tokens.weight相关的优化器状态都找不到)。 原因总结: 在初始化opt的函数中即init_optmizer();会根据当前加载的model参数初始化优化器状态,每一个参数都会为其创建优化器状态,而在ernie4.5非pp的组网中,会使用tie_weight函数使得lm_head与embed相关的两个参数共享同一份tensor,而在训练tp2(ep2)创建优化器状态时,相同param.name的param,只会创建一份优化器状态,并且以第一次出现的key来创建优化器状态参数对应的名称,因此确实lm_head与embed仅仅只有一份权重才对,因此此处是需要优化init_opt部分的逻辑。 然而针对上述逻辑,最终应该是能找到embed相关的优化器状态,而找不到lm相关的优化器状态才对,经过查证,问题在于,AdamW的sharded_state_dict在创建static_to_struct_mapping映射时,未对共用同一个tensor的参数做判断,导致对于共享同一个weight的layer来说,后面layer的参数名(即key)会把前面layer的参数名给覆盖,lm_head在后面,因此覆盖了embed,导致我们在优化器看到的是只有lm_head的优化器状态。因此这里需要优化的是,dygraph_sharding_optimizer和AdamW内的sharded_state_dict函数的逻辑。 但针对ernie的pp组网,查证后发现,并未支持tie_weights操作,lm_head和embed分别独立一份weight;而非pp组网,默认一定调用tie_weights操作,因此在当前情况下,无法做tp2(ep2)->pp4的转换。

7.bias开false时,会遇到报错

B05DD129A49C889F58E11192632DA31E

​ 主要原因是,这里直接对bias做scale,然而当bias为None时,是无法做scale的,导致出错。

image-20250903124127468

​ 做如下修改即可:

image-20250903124234252

8.测tp2(ep2)->tp4,有一个参数的md5未对齐

F311AF56AE807889BA42CCF80274BDE9

最终总结:FLAGS_shard_bypass_dygraph_optimizer 标志位只能控制优化器本身的参数更新,但无法阻止回调函数中的直接参数操作更新。在 on_optimizer_end 阶段,OrthogonalCallback 会计算正交损失并直接更新 ernie.layers.1.mlp.gate.weight 参数,这种直接参数修改绕过了优化器控制机制,导致该参数在 save/load 转换过程中被意外更新,从而造成 MD5 校验失败。证明了我们的FlexCheckpoint框架逻辑没问题。 img img

9. 测tp2(ep2)->pp4,有多个参数的md5未对齐

image-20250904183600844

​ 发现是因为moe模式下,开了image-20250904184242078的原因,导致moe模式下的模型,会多出一个moe_statics.e_score_correction_bias的参数。

10.tp2(ep2)转vpp4,暂时有问题,num_hidden_layer配8层,9层都不对

image-20250905162818225

主要是:multi_token_pred_depth参数开启时,会在hidden_layer中多加一层MTP层,导致以下断言不支持,当前ernie4.5的vpp不支持加MTP这一层:

assert sum(weights) % actual_num_parts == 0

但是关掉multi_token_pred_depth参数后,又报p2p通信错误,看起来是ernie4.5跑vpp自身的bug:

image-1757507777538

11.tp2(ep2)转tp2(ep2)+sd2 和 dp2转tp2(ep2)+sd2时,会出现,某些优化器状态消失得现象

4fec934f-1275-4851-b927-61e60f71c244

也是开了opitimizer_offload的原因

12.DP2转DP4

image-1757510456412

在ernie下跑会报错,因为opt没被封装,它没有inner_opt,而llama下跑不会报错,因为llama下封装了一层:


paddle.distributed.fleet.utils.mix_precision_utils.MixPrecisionOptimizer



因此需要加个判断:

inner_opt = getattr(optimizer, "_inner_opt", None)
if DygraphShardingOptimizer is not None and isinstance(inner_opt, DygraphShardingOptimizer):
local_params = optimizer._rank2params[optimizer._sharding_rank]
optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), local_params)
return

elif DygraphShardingOptimizerV2 is not None and isinstance(inner_opt, DygraphShardingOptimizerV2):

注意:ernie4.5训练时,train函数中调用的self._wrap_model是/home/ERNIE/examples/pre-training/ernie/src/trainers/pretraining_trainer.py内的方法。

13.DP2转Sharding4_V1的时候(开dp_group和sharding_group)

image-1757522487349

总共64个专家,2卡sharding的时候,只有32个专家有优化器状态,4卡sharding的时候只有16个专家有优化器状态;原因是每个rank上的experts组在训练一次后,所有专家的参数被同步了,未具体定位,但训练5步,4张卡上的16个专家参数的md5完全相同;但初始化时,4张卡上的专家参数是不同的。

image-c4a8d393-a7fa-4c6f-95a1-53f665fdc6c1 image-cf8d81f9-c85a-4a0a-825e-7dcbd8cb9b86

每张卡都做了同样的操作,deepcopy fn,而对每个rank来说,这个fn都是同一个layer。 而DP的时候,每个rank image-5aaa869a-f3be-46c0-bf53-0c15f4950234

14.DP2转Sharding4_V2的时候

报错如下: e00ae98c-c570-4371-b75e-8c453814fb4a

15.DP2转Sharding2_EP2

报错如下,主要是在sharding2_ep2转回dp2时报错,但是接续loss 1E-5对齐: 65538414-c00e-485e-bc52-5153491609e8

其实是开了这个的原因 16148

16.DP2、ShardingV1、V2转TP2(EP2)_PP2

存在下面的报错: 7c2756ec-e49e-4db9-bc0d-12e2413ffc9c cbc5d930-d0b4-4db7-af05-e10422ed14d6

主要原因是SequenceParallelLayer没有适配sharded_state_dict 15315

4.测试LLAMA中的一些问题

1.路径需要更换

image-20250821180914606

2.sharding__stage_1_overlap不支持

image-20250821181717133

3.纯tp2时报错:

image-20250821183527560

AdamW需要适配sharded_state_dict

4.纯DP下,fleet显示没有正确初始化

0340047CDE09863031C12D01282B0B0D

​ 报错的原因是,在纯dp的情况下比较特殊,要开unified_checkpoint这参数,才能用fleet.init初始化,不然用的动半的初始化,这样的话self._hcg没有初始化,就不能调用get_hybrid_communicate_group,加上就好了。

image-20250901195619880

5.纯dp会hang住

现象:

image-20250826002129383

原因:

image-20250826002423188

image-20250826002506576

image-20250826002352174

image-20250826002244403

image-20250826002306295

在调用check_unique_id函数时,会调用all_gather获取所有rank的unique_id,而因为纯dp下,should_save被设置为只在0卡保存权重,因此1卡是不会进入save_state_dict中的,而在调用all_gather时,如果process_group为None,则会调用global_group,纯dp2下,即ranks:2 rank_id:0,1;而此时0卡向1卡all_gather请求,1卡却没有做all_gather,0卡就一直等待,最终导致hang住。

​ 解决方案:

image-20250901200028655

​ 添加一个条件,在纯DP时,此时use_hybrid_parallel为false(这是每个rank共同的特征),因此,添加个判断条件,让1卡也进入即可。

6.flash_attention无法正确传入,要手动修改

image-20250901195429250

7.测tp时,fused_qkv, fused_ffn打开后loss接续不符合预期

​ 在测tp策略转换的过程中,发现fused_qkv, fused_ffn打开后loss差距不符合预期;经验证,当前存在fused_qkv与old_fused_qkv两套逻辑,llama当前默认使用的是old_fused_qkv此时无需配置aoa与tp自洽,而ernie使用的是fused_qkv,需要配置aoa。

​ llama下的aoa配置:

 --aoa_config '{ 
"aoa_statements": [
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight, fused_ffn",
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment1_0 -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment1_0, fused_ffn",
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment2_0 -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment2_0, fused_ffn",
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.w_0 -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.w_0, fused_ffn"
]
}' \

​ ernie下的aoa配置:

    aoa_config: {
"aoa_statements": [
"ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight -> ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight.moment1_0 -> ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight.moment1_0, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight.moment2_0 -> ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight.moment2_0, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight.w_0 -> ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight.w_0, fused_qkv, num_heads=20, num_key_value_groups=5",

"ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight -> ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight.moment1_0 -> ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight.moment1_0, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight.moment2_0 -> ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight.moment2_0, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight.w_0 -> ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight.w_0, fused_qkv, num_heads=20, num_key_value_groups=5",

"ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight -> ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight, fused_ffn",
"ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight.moment1_0 -> ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight.moment1_0, fused_ffn",
"ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight.moment2_0 -> ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight.moment2_0, fused_ffn",
"ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight.w_0 -> ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight.w_0, fused_ffn",

"ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight -> ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight, fused_ffn",
"ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight.moment1_0 -> ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight.moment1_0, fused_ffn",
"ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight.moment2_0 -> ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight.moment2_0, fused_ffn",
"ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight.w_0 -> ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight.w_0, fused_ffn",

"ernie.layers.1.mlp.shared_experts.up_gate_proj.weight -> ernie.layers.1.mlp.shared_experts.up_gate_proj.weight, fused_ffn",
"ernie.layers.1.mlp.shared_experts.up_gate_proj.weight.moment1_0 -> ernie.layers.1.mlp.shared_experts.up_gate_proj.weight.moment1_0, fused_ffn",
"ernie.layers.1.mlp.shared_experts.up_gate_proj.weight.moment2_0 -> ernie.layers.1.mlp.shared_experts.up_gate_proj.weight.moment2_0, fused_ffn",
"ernie.layers.1.mlp.shared_experts.up_gate_proj.weight.w_0 -> ernie.layers.1.mlp.shared_experts.up_gate_proj.weight.w_0, fused_ffn",
]
}

fused_qkv(llama)实现逻辑图:

image-20250906104251807

tp2->tp4,num_heads=k_v_nums:

tp2_qkv_fused-_tp4_qkv_fused.drawio

tp2->tp4,num_heads>k_v_nums:

image-20250906112745567

old_fused_qkv(ernie)实现逻辑图

image-20250906104313466

image-20250906104354160

tp2->tp4,num_heads=k_v_nums: 此时逻辑同上,也是均分最后一维。

tp2->tp4,num_heads>k_v_nums:

image-20250906131218483

5.unified_checkpoint与flex_check_point的区别

以tp2为例,flex_check_point保存的权重,是按照参数的部分分片保存的,并没有在最后做allgather:

image-20250905112632509

image-20250905112656881

​ 可以看到,保存下来的embed_tokens参数,仍然是按照vocab_size的大小,切成两份的形式,是一种shard的状态,注意,这里看起来像是batch被切分了,其实是因为,参数一般都以[vocab_size,batch_size]的形式排列,为了后续方便计算。

​ 而unified_ckpt保存的权重,最后会做allgather,即所有rank上的参数都是完整参数,而保存的时候,是将所有参数划分成tp_degree份保存到多个文件中:

image-20250905113611347

6.代码改进意见

1.关于Layer的shared_state_dcit

因为ColumnParallelLinear、VocabParallelEmbedding、RowParallelLinear本质也是继承nn.Layer,因此可以统一用同一个shared_state_dict。添加一个如下的切分方法即可:

def _get_shard_rules(self):
"""子类可重写此方法来提供分片规则"""
return None

image-20250906154640907

7.一些代码记录

1.ShardingGradView


class ShardingGradView:
def __init__(
self,
param,
param_buffer,
grad_buffer,
index,
padded_size,
sharding_degree,
rank,
use_main_grad=False,
release_grad=False,
):
self._param = param
self._param_buffer = param_buffer
self._grad_buffer = grad_buffer
self._index = index
self._padded_size = padded_size
self._sharding_degree = sharding_degree
self._rank = rank
self._use_main_grad = use_main_grad
self._release_grad = release_grad
shard_size = param_buffer._numel() // sharding_degree
rank_begin = max(rank, 0) * shard_size
rank_end = rank_begin + shard_size

param_begin = max(self._index, rank_begin)
param_end = min(self._index + self._padded_size, rank_end)
self._param_begin = param_begin
self._param_end = param_end
self._rank_begin = rank_begin

self._slice_grad = None

if not self._release_grad:
self._link_grad_to_buffer()

# share param buffer
self._share_param_buffer()
def _get_padding(self):
if self._param_begin < self._param_end and self._slice_grad is not None:
padding_start = self._index + self._param._numel()
padding_end = self._index + self._padded_size
padding_start = max(self._param_begin, padding_start)
padding_end = min(self._param_end, padding_end)

if padding_start >= padding_end:
return None

padding = padding_end - padding_start
grad_numel = self._slice_grad._numel()
assert grad_numel >= padding, f"{grad_numel} vs {padding}"
padding_grad = self._slice_grad._slice(
grad_numel - padding, grad_numel
)
return padding_grad
else:
return None
def _slice_grad_from_buffer(self):
assert self._grad_buffer is not None
if self._param_begin < self._param_end:
self._slice_grad = self._grad_buffer._slice(
self._param_begin, self._param_end
)
tmp_grad = self._grad_buffer._slice(
self._index, self._index + self._param._numel()
)
return tmp_grad
def _link_grad_to_buffer(self):
tmp_grad = self._slice_grad_from_buffer()
tmp_grad.get_tensor()._set_dims(self._param.shape)
if not self._use_main_grad:
self._param._copy_gradient_from(tmp_grad)
else:
self._param.main_grad = tmp_grad
这里是根据padded_param减去param的大小,从而得到参数的padding大小,又由于grad和param大小是一致的,因此,切出grad尾部的padding大小的这一部分,就是padding_grad。
可以看到,self._slice_grad在初始化的时候,用的self._param_begin, self._param_end这一段,而这是经过了padded的大小,因此self._slice_grad.numel()等于padded后的param的大小。
而tmp_grad是实际的未padded的大小。也就是参数实际的大小,并且self._grad_buffer和self._param_buffer都是用paddle.zeros创建的全0矩阵,大小为当前param组所有param经过padded之后的大小的和。
因此在_link_grad_to_buffer中,实际是将tmp_grad赋值给param的grad属性,如果有main_grad就给main_grad,这里是浅拷贝,所以就和buffer共享内存了。
def _share_param_buffer(self):
param_shape = self._param.shape
stop_gradient = self._param.stop_gradient
self._param.stop_gradient = True
self._param.flatten_()
paddle.assign(
self._param,
self._param_buffer._slice(
self._index, self._index + self._param._numel()
),
)
self._param.get_tensor()._set_dims(param_shape)
self._param.stop_gradient = stop_gradient
self._param_buffer._slice(
self._index, self._index + self._param._numel()
)._share_buffer_to(self._param)

def fill_slice_param(self, slice_param):
slice_begin = self._param_begin
slice_end = self._param_end
if slice_param._is_initialized():
assert self._param_buffer._is_shared_buffer_with(slice_param)
assert len(slice_param.shape) == 1
assert slice_param.shape[0] == (slice_end - slice_begin)
slice_begin = self._param_begin
slice_end = self._param_end
slice_buffer = self._param_buffer._slice(slice_begin, slice_end)
slice_buffer._share_buffer_to(slice_param)
slice_param.get_tensor()._set_dims([slice_end - slice_begin])

def assign_slice_grad(self, slice_param):
assert self._param_buffer._is_shared_buffer_with(self._param)
slice_grad = self._slice_grad
if slice_grad is None:
return
self.fill_slice_param(slice_param)
if hasattr(self._param, "main_grad"):
if not hasattr(slice_param, "main_grad"):
slice_param.main_grad = slice_grad
else:
assert slice_param.main_grad is slice_grad
elif slice_grad is not None:
if slice_param.grad is None:
slice_param._copy_gradient_from(slice_grad)
else:
assert slice_param.grad._is_shared_buffer_with(slice_grad)

二者的区别是,前者是初始化的时候,将param_buffer的内存与param共享,注意是未padded的,而后者则是padded后的param,并且在当超过当前rank分配的buffer大小时会被截断,而前者不会。 为什么有这两个? 前者是给初始化param使用的,将param和buffer的内存共享,而后者是针对slice_param = EagerParamBase(shape=[1], dtype=param.dtype) 使用得,初始时,不知道大小,通过fill_slice_param来设置实际大小,即,因为每个优化器负责更新部分参数,大多数参数是完整的,但当self._index + self._padded_size超过rank_end时,此时param_end会被设置为rank_end,即将参数截断,一部分在当前rank,另一部分在下一个rank上。这里用EagerParamBase实际是想要实现深拷贝原始的param,原始也是此类型,这样可以使得能够将所有属性完全相同的复制过来。

所以model的param其实是没有被切分的,而slice_param在不跨rank的时候,是完整param的区域+padding区域,而如果出现跨rank的情况,则会被截断,前者用于forward完整计算不切,后者用于仅更新当前rank上的optimizer负责更新的参数,因此做跨rank切分。对应的self.param.main_grad和slice_param.main_grad也是上述区别。

可以看到sharding_param_grad_view中保存了多个ShardingGradView实体,而每个ShardingGradView实体的信息,都能体现当前param的是如何被sharding到每个rank上的优化器上的,即其在param_buffer中的位置信息,因此我们利用这个信息,去获取每个param在当前rank上的切片信息(注意这里是opitimizer要更新的参数的切片信息,而不是对model的param做了切分,实际是复制了一份,可见前文介绍),即flattened_range。

2.opitmizer中的param在每个rank上的划分

可以看到,只要是不属于当前rank的切片,要么①param_begin>param_end;②param_end-index<0。则slice的star,end都取param_begin-index,做空切片,从而控制param_slice_info只保留自己rank上的param的有效切片信息。

3.FusedCommBuffer的逻辑

注意单个fuedCommBuffer并不对应全部参数,optimizer的参数列表会被划分到多个buffer中

deepseek_mermaid_20250906_c19101

4.FusedCommBuffer在DygraphSharding_Optimizer中的应用(v1,v2都用到了,但是只有V2对optimizer的参数做了摊平切分,而V1只是对完整参数做了划分。)

DygraphSharding与DygraphShardingV2的区别是,前者是参数级划分,将opitimizer的参数划分到不同rank上(且为了均衡负载,是依次分配参数,及分配下一个参数是,是看哪个rank上当前分配的参数大小最小,则分配,参数是乱序的);而后者是参数内划分,会将参数列表根据color分成多个buffer,然后每个buffer内的所有参数flatten后,sharding到多个rank上,所以会存在有的参数一部分在该rank,另一部分在下一个rank的现象。DygraphShardingV2不支持fused_param

deepseek_mermaid_20250906_ef7cb1 注意,会根据group_size创建多个FusedCommBuffer,同时这里的tensor_fusion,指的是,比如把多个参数放到同一片存储区域,多个梯度放到同一片存储区域。 deepseek_mermaid_20250906_c05028

8.合参UC测试

1.uc下optimizer的格式是

202509161 因此需要做一个格式转换,把斜杠改成.

2.uc下跑dp,只会保存optimizer和master,具体在如下位置:

20250916 20250916

3.uc下跑SD2EP2时,md5未对齐

image-20250919150818139

​ 经过验证,主要原因是,fc下,会把expert,转换成key不同的情况,也就是说,比如有64个expert,分到两个rank上的时候,编号都是0-31,而fc下,我们会把2卡的参数名修改,做数字偏移,比如0变成rank_id*(per_device_expert_nums),从而区别不同rank上的专家,而uc不会,导致对比时,uc上的expert被覆盖,比较出错,修改脚本验证后,对齐。

4.涉及TP的都会报错

追溯原因:

问题1,uc把不同rank的专家当同一个参数合并

image-20250919164903026

image-20250919165034118

image-20250919165203928

image-20250919171846564

在tp下,如果moe_group是tp,则做恒等映射,是不切分专家的。只是均分到整个moe_group中,如下,fc就是ErnieMoeMLP:

image-20250919172146193

image-20250919172220967

但是注意:

image-20250919184456869

image-20250919184445576

image-20250919235758105

​ 这里moe_group已经被parse_moe_group解析成了group格式,即如图,所以这里moe_in_mp始终为false,所以默认所有参数都按tp合并。所以在这里,用一个moe_group_name来提前接收moe_group字符串。

tp合并参数的逻辑如下

image-20250919185730081

然后 action(ret) 执行,但 experts 的 action 是 partial(fn,is_column)列切,因此按列合并。

因此,两个rank上的expert会被错误合并成一个大的tensor。

问题2,未给专家参数设置mp_moe的标志,导致专家被allgather,而实际应该是直接获取本rank的,非本rank的专家参数设置为None

而,当moe_group直接设置为True的时候,action(ret) 执行,但 experts 的 action 是 lambda x: x,直接返回收集到的张量列表,所以这时候返回的tensor就是一个列表,包含rank0的expert tensor和rank1的expert tensor。仍然会导致报错。

image-20250919192302544

注意这里会得到一个tensor列表的主要原因是,如图处丢掉了expert的p.mp_moe的属性,导致expert无法被识别出有mp_moe属性,导致保存时,仍然保存的不是本地的expert,而是一个expert列表,即本来应该走绿色的这条分支,而现在走了红色的这条分支。

image-20250919201944219

问题3,ernie_moe的_get_tensor_parallel_mappings中,未设置mtp_block层的映射,导致在save_ckpt合参时,该参数未被按tp切分维度合并

image-20250920001657477

因此,需要加入如下映射,标记着其处于切分状态:

image-20250920001730397

问题4,load_state_dict和_handle_aoa未考虑到多卡转单机的情况

最后,load_state_dict和_handle_aoa要适配一下多卡转单机的情况,例如加载tp4的ckpt,到单卡时,也需要用到_handle_aoa。

image-20250920000556562

9.AOAEngine学习记录

1.AOAShardInfoContext

image-20250917135642520

​ 这个主要用于记录上下文信息,保留一些信息,给后续操作可调用。

source_state_shard_infodestination_state_shard_info分别表示需要load下来的ckpt对应策略的参数分片信息,和当前正在执行的策略的参数分片信息,格式为_ShardInfo = dict[str, list[ShardedWeightDesc]],即包含,同一个key,再不同rank上的参数分片状态,如果是类似dp这样的,同一个key只会在单个distcp文件中保存,因此只有一个参数分片状态。

​ get_all_dst_state_keys与get_all_src_state_keys则是辅助函数获取其中所有的key,get_num_hidden_layers通过aoa_config中是否配置了$LAYER_ID,来正则匹配dst中所有key中的layer_id,例如下:

"ernie.layers.$LAYER_ID.self_attn.qkv_proj.weight -> ernie.layers.$LAYER_ID.self_attn.qkv_proj.weight, fused_qkv_old, num_heads=20, num_key_value_groups=5"

​ 会以$LAYER_ID为分隔符,分成两份,然后中间以\d匹配,从而匹配到layer_id,遍历所有key,得到的做大ID+1,则为num_hidden_layer的层数。

get_src_state_shard_numget_dst_state_shard_num这两个主要是查看当前key对应参数的分片数,即tp数。

​ 为什么要把optmizer的key也都转换成model的key来算呢,原因是,当做sharding的时候,opt的参数分片数=tp_nums*sharding_nums,直接求就有问题了。

13d81cca-3251-4777-96ff-bf68772a5c2c

2.Lexer

image-20250918153731669

​ 核心目标:为 AoA 表达式做词法分析(Lexing),并在词法分析前先应用已注册的宏展开,最终生成供解析器使用的 token 序列。

​ 首先传入的参数expressions是aoa_conifg["aoa_statements"],这是一个字符串列表,形状如下:

 --aoa_config '{ 
"aoa_statements": [
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight, fused_ffn",
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment1_0 -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment1_0, fused_ffn",
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment2_0 -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment2_0, fused_ffn",
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.w_0 -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.w_0, fused_ffn"
]
}' \

​ 每一个expression会使用apply_macros,即对每个expression,遍历使用所有的已经注册好的macro。

​ 在进入macro之前,会使用tokenize方法将expression解析成多个token,按照token_specification中的正则项进行匹配,name作为key,匹配到的实际内容作为value,比如上述的aoa_config的第一条,首先会根据identifier获取到第一个token:llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight,遇到空格会skip,然后根据rarrow匹配到->,紧接着再根据identifier获取到下一个token,知道最终结束,而每个text都会判断一下后面有没有\n,没有就补充,从而得到NEWLINE,标志着一条text匹配结束。

​ 被所有macro处理后,会得到一个results列表,列表里面也都是expression样子的表达式,最终Lexer会把result_expression再次调用tokenized解析成token返回,给到parser里面做处理。

3.Parser